{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "import csv\n",
    "import os\n",
    "import torch\n",
    "from torch_geometric.data import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled clicks table into `df`.\n",
    "# NOTE: pickle.load can execute arbitrary code - only open pickle files you created yourself.\n",
    "with open('../input/yoochoose-clicks.p','rb') as f:\n",
    "    df = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>item_id</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>2014-04-07T10:51:09.277Z</td>\n",
       "      <td>214536502</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>2014-04-07T10:54:09.868Z</td>\n",
       "      <td>214536500</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>2014-04-07T10:54:46.998Z</td>\n",
       "      <td>214536506</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>2014-04-07T10:57:00.306Z</td>\n",
       "      <td>214577561</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>2014-04-07T13:56:37.614Z</td>\n",
       "      <td>214662742</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   session_id                 timestamp    item_id category\n",
       "0           1  2014-04-07T10:51:09.277Z  214536502        0\n",
       "1           1  2014-04-07T10:54:09.868Z  214536500        0\n",
       "2           1  2014-04-07T10:54:46.998Z  214536506        0\n",
       "3           1  2014-04-07T10:57:00.306Z  214577561        0\n",
       "4           2  2014-04-07T13:56:37.614Z  214662742        0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>item_id</th>\n",
       "      <th>price</th>\n",
       "      <th>quantity</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>420374</td>\n",
       "      <td>2014-04-06T18:44:58.314Z</td>\n",
       "      <td>214537888</td>\n",
       "      <td>12462</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>420374</td>\n",
       "      <td>2014-04-06T18:44:58.325Z</td>\n",
       "      <td>214537850</td>\n",
       "      <td>10471</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>281626</td>\n",
       "      <td>2014-04-06T09:40:13.032Z</td>\n",
       "      <td>214535653</td>\n",
       "      <td>1883</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>420368</td>\n",
       "      <td>2014-04-04T06:13:28.848Z</td>\n",
       "      <td>214530572</td>\n",
       "      <td>6073</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>420368</td>\n",
       "      <td>2014-04-04T06:13:28.858Z</td>\n",
       "      <td>214835025</td>\n",
       "      <td>2617</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>140806</td>\n",
       "      <td>2014-04-07T09:22:28.132Z</td>\n",
       "      <td>214668193</td>\n",
       "      <td>523</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>140806</td>\n",
       "      <td>2014-04-07T09:22:28.176Z</td>\n",
       "      <td>214587399</td>\n",
       "      <td>1046</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>140806</td>\n",
       "      <td>2014-04-07T09:22:28.219Z</td>\n",
       "      <td>214586690</td>\n",
       "      <td>837</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>140806</td>\n",
       "      <td>2014-04-07T09:22:28.268Z</td>\n",
       "      <td>214774667</td>\n",
       "      <td>1151</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>140806</td>\n",
       "      <td>2014-04-07T09:22:28.280Z</td>\n",
       "      <td>214578823</td>\n",
       "      <td>1046</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>11</td>\n",
       "      <td>2014-04-03T11:04:11.417Z</td>\n",
       "      <td>214821371</td>\n",
       "      <td>1046</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>11</td>\n",
       "      <td>2014-04-03T11:04:18.097Z</td>\n",
       "      <td>214821371</td>\n",
       "      <td>1046</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>12</td>\n",
       "      <td>2014-04-02T10:42:17.227Z</td>\n",
       "      <td>214717867</td>\n",
       "      <td>1778</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>489758</td>\n",
       "      <td>2014-04-06T09:59:52.422Z</td>\n",
       "      <td>214826955</td>\n",
       "      <td>1360</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>489758</td>\n",
       "      <td>2014-04-06T09:59:52.476Z</td>\n",
       "      <td>214826715</td>\n",
       "      <td>732</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>489758</td>\n",
       "      <td>2014-04-06T09:59:52.578Z</td>\n",
       "      <td>214827026</td>\n",
       "      <td>1046</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>140802</td>\n",
       "      <td>2014-04-02T16:52:21.678Z</td>\n",
       "      <td>214716984</td>\n",
       "      <td>1883</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>489756</td>\n",
       "      <td>2014-04-05T16:51:59.947Z</td>\n",
       "      <td>214716932</td>\n",
       "      <td>6177</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>420378</td>\n",
       "      <td>2014-04-03T07:41:02.566Z</td>\n",
       "      <td>214821017</td>\n",
       "      <td>1046</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>420378</td>\n",
       "      <td>2014-04-03T07:41:02.616Z</td>\n",
       "      <td>214821020</td>\n",
       "      <td>732</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    session_id                 timestamp    item_id  price  quantity\n",
       "0       420374  2014-04-06T18:44:58.314Z  214537888  12462         1\n",
       "1       420374  2014-04-06T18:44:58.325Z  214537850  10471         1\n",
       "2       281626  2014-04-06T09:40:13.032Z  214535653   1883         1\n",
       "3       420368  2014-04-04T06:13:28.848Z  214530572   6073         1\n",
       "4       420368  2014-04-04T06:13:28.858Z  214835025   2617         1\n",
       "5       140806  2014-04-07T09:22:28.132Z  214668193    523         1\n",
       "6       140806  2014-04-07T09:22:28.176Z  214587399   1046         1\n",
       "7       140806  2014-04-07T09:22:28.219Z  214586690    837         1\n",
       "8       140806  2014-04-07T09:22:28.268Z  214774667   1151         1\n",
       "9       140806  2014-04-07T09:22:28.280Z  214578823   1046         1\n",
       "10          11  2014-04-03T11:04:11.417Z  214821371   1046         1\n",
       "11          11  2014-04-03T11:04:18.097Z  214821371   1046         1\n",
       "12          12  2014-04-02T10:42:17.227Z  214717867   1778         4\n",
       "13      489758  2014-04-06T09:59:52.422Z  214826955   1360         2\n",
       "14      489758  2014-04-06T09:59:52.476Z  214826715    732         2\n",
       "15      489758  2014-04-06T09:59:52.578Z  214827026   1046         1\n",
       "16      140802  2014-04-02T16:52:21.678Z  214716984   1883         1\n",
       "17      489756  2014-04-05T16:51:59.947Z  214716932   6177         1\n",
       "18      420378  2014-04-03T07:41:02.566Z  214821017   1046         1\n",
       "19      420378  2014-04-03T07:41:02.616Z  214821020    732         1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The buys file has no header row, so the column names are supplied while reading.\n",
    "buy_df = pd.read_csv(\n",
    "    '../input/yoochoose-buys.dat',\n",
    "    header=None,\n",
    "    names=['session_id', 'timestamp', 'item_id', 'price', 'quantity'],\n",
    ")\n",
    "buy_df.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "session_id     509696\n",
       "timestamp     1136477\n",
       "item_id         19949\n",
       "price             735\n",
       "quantity           28\n",
       "dtype: int64"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "buy_df.nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "session_id     9249729\n",
       "timestamp     32937845\n",
       "item_id          52739\n",
       "category           340\n",
       "dtype: int64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "session_id     4431931\n",
       "timestamp     24590089\n",
       "item_id          48255\n",
       "category           331\n",
       "dtype: int64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the rows of sessions that contain more than two click events.\n",
    "df = df.loc[df.session_id.map(df.groupby('session_id').size() > 2)]\n",
    "df.nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "session_id    1000000\n",
       "timestamp     5546275\n",
       "item_id         37353\n",
       "category          260\n",
       "dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Randomly sample 1,000,000 of the remaining session ids (without replacement)\n",
    "# and keep only the clicks that belong to them.\n",
    "sampled_session_id = np.random.choice(df.session_id.unique(), 1000000, replace=False)\n",
    "df = df.loc[df.session_id.isin(sampled_session_id)]\n",
    "df.nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "session_id    0\n",
       "timestamp     0\n",
       "item_id       0\n",
       "category      0\n",
       "dtype: int64"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.isna().sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.548274"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# average length of session \n",
    "df.groupby('session_id')['item_id'].size().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>item_id</th>\n",
       "      <th>category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>3</td>\n",
       "      <td>2014-04-02T13:17:46.940Z</td>\n",
       "      <td>21302</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>3</td>\n",
       "      <td>2014-04-02T13:26:02.515Z</td>\n",
       "      <td>25022</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>3</td>\n",
       "      <td>2014-04-02T13:30:12.318Z</td>\n",
       "      <td>29039</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>19</td>\n",
       "      <td>2014-04-01T20:52:12.357Z</td>\n",
       "      <td>5749</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50</th>\n",
       "      <td>19</td>\n",
       "      <td>2014-04-01T20:52:13.758Z</td>\n",
       "      <td>5749</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    session_id                 timestamp  item_id category\n",
       "10           3  2014-04-02T13:17:46.940Z    21302        0\n",
       "11           3  2014-04-02T13:26:02.515Z    25022        0\n",
       "12           3  2014-04-02T13:30:12.318Z    29039        0\n",
       "49          19  2014-04-01T20:52:12.357Z     5749        0\n",
       "50          19  2014-04-01T20:52:13.758Z     5749        0"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# Re-map the raw item ids to contiguous integers starting at 0 so they can be used\n",
    "# directly as row indices into the model's item embedding table.\n",
    "item_encoder = LabelEncoder()\n",
    "df['item_id'] = item_encoder.fit_transform(df.item_id)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>session_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>item_id</th>\n",
       "      <th>category</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>3</td>\n",
       "      <td>2014-04-02T13:17:46.940Z</td>\n",
       "      <td>21302</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>3</td>\n",
       "      <td>2014-04-02T13:26:02.515Z</td>\n",
       "      <td>25022</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>3</td>\n",
       "      <td>2014-04-02T13:30:12.318Z</td>\n",
       "      <td>29039</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>19</td>\n",
       "      <td>2014-04-01T20:52:12.357Z</td>\n",
       "      <td>5749</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50</th>\n",
       "      <td>19</td>\n",
       "      <td>2014-04-01T20:52:13.758Z</td>\n",
       "      <td>5749</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    session_id                 timestamp  item_id category  label\n",
       "10           3  2014-04-02T13:17:46.940Z    21302        0  False\n",
       "11           3  2014-04-02T13:26:02.515Z    25022        0  False\n",
       "12           3  2014-04-02T13:30:12.318Z    29039        0  False\n",
       "49          19  2014-04-01T20:52:12.357Z     5749        0  False\n",
       "50          19  2014-04-01T20:52:13.758Z     5749        0  False"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A click row is labelled True when its session id appears in the buy events.\n",
    "df['label'] = df.session_id.isin(buy_df.session_id)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.085507"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.drop_duplicates('session_id')['label'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import InMemoryDataset\n",
    "from tqdm import tqdm\n",
    "\n",
    "class YooChooseBinaryDataset(InMemoryDataset):\n",
    "    \"\"\"In-memory graph dataset holding one graph per click session.\n",
    "\n",
    "    process() turns every session of the notebook-level `df` into a `Data`\n",
    "    object: the nodes are the distinct items of the session, the edges link\n",
    "    consecutive clicks, and `y` is the session's `label`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root, transform=None, pre_transform=None):\n",
    "        super(YooChooseBinaryDataset, self).__init__(root, transform, pre_transform)\n",
    "        # Load the collated graphs that process() saved to the processed file.\n",
    "        self.data, self.slices = torch.load(self.processed_paths[0])\n",
    "\n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        # No raw files: the graphs are built from the in-memory `df`.\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return ['../input/yoochoose_click_binary_1M_sess.dataset']\n",
    "\n",
    "    def download(self):\n",
    "        pass\n",
    "\n",
    "    def process(self):\n",
    "        \"\"\"Build one graph per session from `df` and save the collated result.\"\"\"\n",
    "        data_list = []\n",
    "\n",
    "        # process by session_id\n",
    "        grouped = df.groupby('session_id')\n",
    "        for session_id, group in tqdm(grouped):\n",
    "            # Encode the session's items as local node indices 0..n-1;\n",
    "            # encoder.classes_[i] is the (global) item id of local node i.\n",
    "            encoder = LabelEncoder()\n",
    "            sess_item_id = encoder.fit_transform(group.item_id)\n",
    "            x = torch.LongTensor(encoder.classes_).unsqueeze(1)\n",
    "\n",
    "            # Consecutive clicks become directed edges. Stack into a single (2, E)\n",
    "            # array first: torch.tensor on a list of ndarrays is very slow.\n",
    "            edge_index = torch.tensor(np.vstack([sess_item_id[:-1], sess_item_id[1:]]),\n",
    "                                      dtype=torch.long)\n",
    "\n",
    "            # The label is constant within a session, so the first row suffices.\n",
    "            y = torch.FloatTensor([group.label.values[0]])\n",
    "\n",
    "            data_list.append(Data(x=x, edge_index=edge_index, y=y))\n",
    "\n",
    "        data, slices = self.collate(data_list)\n",
    "        torch.save((data, slices), self.processed_paths[0])\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = YooChooseBinaryDataset(root='../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(800000, 100000, 100000)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shuffle once, then split 80/10/10. The hard-coded boundaries assume the\n",
    "# 1,000,000 sessions sampled earlier in the notebook.\n",
    "dataset = dataset.shuffle()\n",
    "train_dataset = dataset[:800000]\n",
    "val_dataset = dataset[800000:900000]\n",
    "test_dataset = dataset[900000:]\n",
    "len(train_dataset), len(val_dataset), len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import DataLoader\n",
    "batch_size= 1024\n",
    "# Re-shuffle the training graphs at the start of every epoch so that mini-batches\n",
    "# differ between epochs; the evaluation loaders keep a fixed order.\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "37353"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_items = df.item_id.max() +1 \n",
    "num_items "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import Sequential as Seq, Linear, ReLU\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import remove_self_loops, add_self_loops\n",
    "class SAGEConv(MessagePassing):\n",
    "    \"\"\"GraphSAGE-style convolution with max aggregation.\n",
    "\n",
    "    Every neighbour message is ReLU(Linear(x_j)). The aggregated messages are\n",
    "    concatenated with the node's own features and passed through a bias-free\n",
    "    linear layer followed by ReLU.\n",
    "\n",
    "    Args:\n",
    "        in_channels: size of each input node feature vector.\n",
    "        out_channels: size of each output node feature vector.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super(SAGEConv, self).__init__(aggr='max')  # max aggregation\n",
    "        self.lin = torch.nn.Linear(in_channels, out_channels)\n",
    "        self.act = torch.nn.ReLU()\n",
    "        # Project [aggregated message | own features] to out_channels. The layer used to\n",
    "        # emit in_channels, which is only correct when in_channels == out_channels.\n",
    "        self.update_lin = torch.nn.Linear(in_channels + out_channels, out_channels, bias=False)\n",
    "        self.update_act = torch.nn.ReLU()\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        # x has shape [N, in_channels]\n",
    "        # edge_index has shape [2, E]\n",
    "\n",
    "        # Normalise self loops: drop any existing ones, then add exactly one per node.\n",
    "        edge_index, _ = remove_self_loops(edge_index)\n",
    "        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n",
    "\n",
    "        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)\n",
    "\n",
    "    def message(self, x_j):\n",
    "        # x_j has shape [E, in_channels]\n",
    "        return self.act(self.lin(x_j))\n",
    "\n",
    "    def update(self, aggr_out, x):\n",
    "        # aggr_out has shape [N, out_channels]\n",
    "        new_embedding = torch.cat([aggr_out, x], dim=1)\n",
    "        return self.update_act(self.update_lin(new_embedding))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_dim = 128\n",
    "from torch_geometric.nn import GraphConv, TopKPooling, GatedGraphConv\n",
    "from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp\n",
    "import torch.nn.functional as F\n",
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Session-level binary classifier over item-click graphs.\n",
    "\n",
    "    Item ids are embedded, passed through three SAGEConv + TopKPooling stages,\n",
    "    and the max/mean readouts of the three stages are summed and fed to a\n",
    "    small MLP whose sigmoid output is returned.\n",
    "\n",
    "    Args:\n",
    "        num_embeddings: number of rows of the item embedding table. Defaults to\n",
    "            df.item_id.max() + 1, computed from the notebook-level `df`.\n",
    "        embedding_dim: width of the item embeddings. Defaults to the\n",
    "            notebook-level `embed_dim`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_embeddings=None, embedding_dim=None):\n",
    "        super(Net, self).__init__()\n",
    "        # Fall back to the notebook globals so that a bare Net() keeps working.\n",
    "        if num_embeddings is None:\n",
    "            num_embeddings = int(df.item_id.max() + 1)\n",
    "        if embedding_dim is None:\n",
    "            embedding_dim = embed_dim\n",
    "\n",
    "        self.conv1 = SAGEConv(embedding_dim, 128)\n",
    "        self.pool1 = TopKPooling(128, ratio=0.8)\n",
    "        self.conv2 = SAGEConv(128, 128)\n",
    "        self.pool2 = TopKPooling(128, ratio=0.8)\n",
    "        self.conv3 = SAGEConv(128, 128)\n",
    "        self.pool3 = TopKPooling(128, ratio=0.8)\n",
    "        self.item_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)\n",
    "        self.lin1 = torch.nn.Linear(256, 128)\n",
    "        self.lin2 = torch.nn.Linear(128, 64)\n",
    "        self.lin3 = torch.nn.Linear(64, 1)\n",
    "        # NOTE: bn1/bn2 are not applied in forward(); they are kept so the\n",
    "        # module's parameter set stays unchanged.\n",
    "        self.bn1 = torch.nn.BatchNorm1d(128)\n",
    "        self.bn2 = torch.nn.BatchNorm1d(64)\n",
    "        self.act1 = torch.nn.ReLU()\n",
    "        self.act2 = torch.nn.ReLU()\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index, batch = data.x, data.edge_index, data.batch\n",
    "        # data.x carries one item id per node in a trailing axis of size 1;\n",
    "        # squeeze that axis away after the embedding lookup.\n",
    "        x = self.item_embedding(x)\n",
    "        x = x.squeeze(1)\n",
    "\n",
    "        # Three conv -> pool stages; keep a [max | mean] graph readout after each.\n",
    "        stages = ((self.conv1, self.pool1), (self.conv2, self.pool2), (self.conv3, self.pool3))\n",
    "        readouts = []\n",
    "        for conv, pool in stages:\n",
    "            x = F.relu(conv(x, edge_index))\n",
    "            x, edge_index, _, batch, _ = pool(x, edge_index, None, batch)\n",
    "            readouts.append(torch.cat([gmp(x, batch), gap(x, batch)], dim=1))\n",
    "\n",
    "        x = readouts[0] + readouts[1] + readouts[2]\n",
    "\n",
    "        x = self.lin1(x)\n",
    "        x = self.act1(x)\n",
    "        x = self.lin2(x)\n",
    "        x = self.act2(x)\n",
    "        x = F.dropout(x, p=0.5, training=self.training)\n",
    "\n",
    "        # One probability per graph in the batch.\n",
    "        x = torch.sigmoid(self.lin3(x)).squeeze(1)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Use the GPU when one is available, otherwise fall back to the CPU so the\n",
    "# notebook still runs on machines without CUDA.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = Net().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "crit = torch.nn.BCELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one optimisation pass over `train_loader`.\n",
    "\n",
    "    Returns:\n",
    "        The batch losses weighted by the number of graphs per batch and\n",
    "        divided by len(train_dataset).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    running_loss = 0\n",
    "    for batch in train_loader:\n",
    "        batch = batch.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        prediction = model(batch)\n",
    "        target = batch.y.to(device)\n",
    "        batch_loss = crit(prediction, target)\n",
    "        batch_loss.backward()\n",
    "        # Weight by batch size because the final batch may be smaller.\n",
    "        running_loss += batch.num_graphs * batch_loss.item()\n",
    "        optimizer.step()\n",
    "    return running_loss / len(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "def evaluate(loader):\n",
    "    \"\"\"Return the ROC-AUC of `model` over all batches yielded by `loader`.\"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    batch_scores, batch_targets = [], []\n",
    "\n",
    "    # Gradients are not needed for scoring.\n",
    "    with torch.no_grad():\n",
    "        for batch in loader:\n",
    "            batch = batch.to(device)\n",
    "            scores = model(batch).detach().cpu().numpy()\n",
    "            targets = batch.y.detach().cpu().numpy()\n",
    "            batch_scores.append(scores)\n",
    "            batch_targets.append(targets)\n",
    "\n",
    "    # Concatenate the per-batch arrays into one vector each before scoring.\n",
    "    return roc_auc_score(np.hstack(batch_targets), np.hstack(batch_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000, Loss: 0.28378, Train Auc: 0.76702, Val Auc: 0.73039, Test Auc: 0.72918\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1):\n",
    "    loss = train()\n",
    "    # NOTE: despite the *_acc names, these hold the ROC-AUC returned by evaluate().\n",
    "    train_acc, val_acc, test_acc = [\n",
    "        evaluate(split_loader) for split_loader in (train_loader, val_loader, test_loader)\n",
    "    ]\n",
    "    print('Epoch: {:03d}, Loss: {:.5f}, Train Auc: {:.5f}, Val Auc: {:.5f}, Test Auc: {:.5f}'.format(\n",
    "        epoch, loss, train_acc, val_acc, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
